{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### EM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 导入numpy库 \n",
    "import numpy as np\n",
    "\n",
    "### EM算法过程函数定义\n",
    "def em(data, thetas, max_iter=30, eps=1e-3):\n",
    "    '''\n",
    "    输入：\n",
    "    data：观测数据\n",
    "    thetas：初始化的估计参数值\n",
    "    max_iter：最大迭代次数\n",
    "    eps：收敛阈值\n",
    "    输出：\n",
    "    thetas：估计参数\n",
    "    '''\n",
    "    # 初始化似然函数值\n",
    "    ll_old = -np.infty\n",
    "    for i in range(max_iter):\n",
    "        ### E步：求隐变量分布\n",
    "        # 对数似然\n",
    "        log_like = np.array([np.sum(data * np.log(theta), axis=1) for theta in thetas])\n",
    "        # 似然\n",
    "        like = np.exp(log_like)\n",
    "        # 求隐变量分布\n",
    "        ws = like/like.sum(0)\n",
    "        # 概率加权\n",
    "        vs = np.array([w[:, None] * data for w in ws])\n",
    "        ### M步：更新参数值\n",
    "        thetas = np.array([v.sum(0)/v.sum() for v in vs])\n",
    "        # 更新似然函数\n",
    "        ll_new = np.sum([w*l for w, l in zip(ws, log_like)])\n",
    "        print(\"Iteration: %d\" % (i+1))\n",
    "        print(\"theta_B = %.2f, theta_C = %.2f, ll = %.2f\" \n",
    "              % (thetas[0,0], thetas[1,0], ll_new))\n",
    "        # 满足迭代条件即退出迭代\n",
    "        if np.abs(ll_new - ll_old) < eps:\n",
    "            break\n",
    "        ll_old = ll_new\n",
    "    return thetas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iteration: 1\n",
      "theta_B = 0.71, theta_C = 0.58, ll = -32.69\n",
      "Iteration: 2\n",
      "theta_B = 0.75, theta_C = 0.57, ll = -31.26\n",
      "Iteration: 3\n",
      "theta_B = 0.77, theta_C = 0.55, ll = -30.76\n",
      "Iteration: 4\n",
      "theta_B = 0.78, theta_C = 0.53, ll = -30.33\n",
      "Iteration: 5\n",
      "theta_B = 0.79, theta_C = 0.53, ll = -30.07\n",
      "Iteration: 6\n",
      "theta_B = 0.79, theta_C = 0.52, ll = -29.95\n",
      "Iteration: 7\n",
      "theta_B = 0.80, theta_C = 0.52, ll = -29.90\n",
      "Iteration: 8\n",
      "theta_B = 0.80, theta_C = 0.52, ll = -29.88\n",
      "Iteration: 9\n",
      "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n",
      "Iteration: 10\n",
      "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n",
      "Iteration: 11\n",
      "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n",
      "Iteration: 12\n",
      "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n"
     ]
    }
   ],
   "source": [
    "# 观测数据，5次独立试验，每次试验10次抛掷的正反次数\n",
    "# 比如第一次试验为5次正面5次反面\n",
    "observed_data = np.array([(5,5), (9,1), (8,2), (4,6), (7,3)])\n",
    "# 初始化参数值，即硬币B的正面概率为0.6，硬币C的正面概率为0.5\n",
    "thetas = np.array([[0.6, 0.4], [0.5, 0.5]])\n",
    "thetas = em(observed_data, thetas, max_iter=30, eps=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.7967829 , 0.2032171 ],\n",
       "       [0.51959543, 0.48040457]])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "thetas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
